package com.boot.websocket.message;

import cn.hutool.json.JSONUtil;
import com.boot.websocket.model.MsgDTO;
import jakarta.websocket.*;
import jakarta.websocket.server.ServerEndpoint;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @author 知了一笑
 * @date 2024-05-04 15:41
 */
@Slf4j
@Service
@ServerEndpoint("/web/socket/msg")
public class MsgWebSocket {

    private static final  ConcurrentHashMap<String,Session> sessions = new ConcurrentHashMap<>();

    private static final AtomicInteger onlineCount = new AtomicInteger(0);

    /**
     * 建立连接调用的方法
     */
    @OnOpen
    public void onOpen(Session session) {
        String userId = session.getRequestParameterMap().get("userId").get(0);
        // 加入Set中
        sessions.put(userId,session);
        // 在线数增加
        onlineCount.getAndIncrement();
        log.info("session-{},online-count-{}",session.getId(),onlineCount.get());
    }

    /**
     * 客户端消息处理的方法
     */
    @OnMessage
    public void sendMsg(Session sender,String message) throws Exception {
        MsgDTO dto = JSONUtil.toBean(message, MsgDTO.class);
        Session receiver = sessions.get(dto.getUserId());
        if (receiver != null) {
            receiver.getBasicRemote().sendText(dto.getMsg());
        }
    }

    /**
     * 关闭连接调用的方法
     */
    @OnClose
    public void onClose(Session session) {
        String userId = session.getRequestParameterMap().get("userId").get(0);
        // 从Set中删除
        sessions.remove(userId);
        // 在线数减少
        onlineCount.getAndDecrement();
        log.info("session-{},down-line-count-{}",session.getId(),onlineCount.get());
    }

    /**
     * 发生错误调用的方法
     */
    @OnError
    public void onError(Session session, Throwable throwable) throws Exception {
        log.error("Web Stock Error", throwable);
        session.getBasicRemote().sendText(throwable.getMessage());
    }
}
